﻿using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Rougamo;
using Rougamo.Context;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;

namespace XC.Framework.Aop.Attributes
{
    /// <summary>
    /// Base Rougamo attribute that gives derived AOP attributes access to dependency-injection services.
    /// Services are resolved from, in order: a provider exposed by the intercepted instance, a provider
    /// registered for the current async flow, and finally a fresh scope of the registered host.
    /// </summary>
    public class AopMoAttribute : MoAttribute
    {
        /// <summary>Host registered through <see cref="SetHost"/>; used to create fallback scopes.</summary>
        private static IHost? Host;

        /// <summary>Provider registered through <see cref="SetServiceProvider"/> for the current async flow.</summary>
        private static readonly AsyncLocal<IServiceProvider?> ServiceProvider = new();

        /// <summary>Scope most recently created by <see cref="GetScope"/> in the current async flow.</summary>
        private static readonly AsyncLocal<IServiceScope?> Scope = new();

        /// <summary>Guards reads/writes of the host field and scope creation.</summary>
        private static readonly object _lock = new();

        /// <summary>
        /// Registers the application host used to create fallback service scopes.
        /// Only the first registration takes effect; later calls are ignored.
        /// </summary>
        /// <param name="host">The application host.</param>
        public static void SetHost(IHost host)
        {
            // Locked so two concurrent callers cannot both observe "no host yet";
            // the first registration wins.
            lock (_lock)
            {
                Host ??= host;
            }
        }

        /// <summary>
        /// Sets the service provider consulted by <see cref="GetService(MethodContext, Type)"/> when the
        /// intercepted instance does not supply one. The value is stored in an <see cref="AsyncLocal{T}"/>,
        /// so it applies to the current async flow only.
        /// </summary>
        /// <param name="serviceProvider">The provider to use.</param>
        public static void SetServiceProvider(IServiceProvider serviceProvider) => ServiceProvider.Value = serviceProvider;

        /// <summary>
        /// Creates a new service scope from the registered host. The new scope replaces the scope
        /// previously created in the current async flow.
        /// </summary>
        /// <returns>The new scope, or <c>null</c> when no host has been registered.</returns>
        /// <remarks>
        /// NOTE(review): the previous scope of this async flow is disposed on every call. Services a caller
        /// obtained from that scope earlier may therefore be disposed while still in use. The most recent
        /// scope is only disposed by the next call. Confirm this lifetime is intended.
        /// </remarks>
        protected static IServiceScope? GetScope()
        {
            lock (_lock)
            {
                Scope.Value?.Dispose();
                Scope.Value = Host?.Services?.CreateScope();
                return Scope.Value;
            }
        }

        /// <summary>
        /// Resolves a service of the given type.
        /// </summary>
        /// <param name="context">The intercepted method context; its target may expose a provider.</param>
        /// <param name="type">The service type to resolve.</param>
        /// <returns>
        /// The service, or <c>null</c> when it is not registered or no provider is available.
        /// </returns>
        /// <remarks>
        /// Lookup order:
        /// <list type="number">
        /// <item>the provider found by <see cref="GetServiceProvider"/>, used only if it returns a non-null service;</item>
        /// <item>the provider set via <see cref="SetServiceProvider"/>, whose result (even <c>null</c>) is returned;</item>
        /// <item>a fresh host scope from <see cref="GetScope"/>.</item>
        /// </list>
        /// If any step throws, the lookup is retried once against a fresh host scope.
        /// </remarks>
        protected object? GetService(MethodContext context, Type type)
        {
            try
            {
                var service = GetServiceProvider(context)?.GetService(type);
                if (service is not null)
                {
                    return service;
                }

                var ambientProvider = ServiceProvider.Value;
                if (ambientProvider is not null)
                {
                    return ambientProvider.GetService(type);
                }

                return GetScope()?.ServiceProvider.GetService(type);
            }
            catch (Exception)
            {
                // Deliberate fallback, e.g. when the target's provider has already been disposed.
                return GetScope()?.ServiceProvider.GetService(type);
            }
        }

        /// <summary>
        /// Resolves a service of type <typeparamref name="TService"/> using the same lookup order as
        /// <see cref="GetService(MethodContext, Type)"/>.
        /// </summary>
        /// <typeparam name="TService">The service type to resolve.</typeparam>
        /// <param name="context">The intercepted method context.</param>
        /// <returns>The service, or <c>default</c> when it cannot be resolved.</returns>
        protected TService? GetService<TService>(MethodContext context) =>
            GetService(context, typeof(TService)) is TService service ? service : default;

        /// <summary>
        /// Resolves an <see cref="ILogger{TCategoryName}"/> whose category is the intercepted target type.
        /// <typeparamref name="T"/> is used as the category when the target type is unavailable or
        /// cannot be used to build the logger type.
        /// </summary>
        /// <typeparam name="T">Fallback logger category.</typeparam>
        /// <param name="context">The intercepted method context.</param>
        /// <returns>The logger, or <c>null</c> when none is registered.</returns>
        protected ILogger? GetLogger<T>(MethodContext context)
        {
            // A null TargetType falls back to T directly instead of going through an exception.
            var categoryType = context.TargetType ?? typeof(T);

            try
            {
                var loggerType = typeof(ILogger<>).MakeGenericType(categoryType);
                return GetService(context, loggerType) as ILogger;
            }
            catch (Exception)
            {
                // MakeGenericType rejects some types (e.g. by-ref/pointer types); retry with T.
                var loggerType = typeof(ILogger<>).MakeGenericType(typeof(T));
                return GetService(context, loggerType) as ILogger;
            }
        }

        /// <summary>
        /// Locates an <see cref="IServiceProvider"/> exposed by the intercepted instance.
        /// </summary>
        /// <param name="context">The intercepted method context.</param>
        /// <param name="type">
        /// The type whose members are scanned. Defaults to the runtime type of <c>context.Target</c>.
        /// </param>
        /// <returns>The provider, or <c>null</c> when none is found.</returns>
        /// <remarks>
        /// If the target implements <c>IAopServiceProvider</c>, its <c>ServiceProvider</c> is returned directly.
        /// Otherwise the properties of <paramref name="type"/> are scanned first, then its fields, looking for
        /// members declared exactly as <see cref="IServiceProvider"/>. The first non-null value wins.
        /// When nothing is found, the method recurses into the base type so that private members of base
        /// classes are inspected too. Reflection failures (e.g. a throwing property getter) are treated as
        /// "not found".
        /// </remarks>
        protected virtual IServiceProvider? GetServiceProvider(MethodContext context, Type? type = null)
        {
            if (context.Target is IAopServiceProvider aopServiceProvider)
            {
                return aopServiceProvider.ServiceProvider;
            }

            const BindingFlags flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static;

            try
            {
                type ??= context.Target?.GetType();

                if (type is null)
                {
                    return null;
                }

                var target = context.Target;

                foreach (var property in type.GetProperties(flags))
                {
                    if (property.PropertyType != typeof(IServiceProvider))
                    {
                        continue;
                    }

                    // Skip write-only properties and indexers. Also skip instance properties when there is
                    // no instance to read from. Each of these would throw from GetValue.
                    var getter = property.GetMethod;
                    if (getter is null || property.GetIndexParameters().Length > 0 || (!getter.IsStatic && target is null))
                    {
                        continue;
                    }

                    if (property.GetValue(target) is IServiceProvider fromProperty)
                    {
                        return fromProperty;
                    }
                }

                // Fields are only consulted when no property yielded a provider, so a null field can no
                // longer discard a provider already found on a property.
                foreach (var field in type.GetFields(flags))
                {
                    if (field.FieldType != typeof(IServiceProvider) || (!field.IsStatic && target is null))
                    {
                        continue;
                    }

                    if (field.GetValue(target) is IServiceProvider fromField)
                    {
                        return fromField;
                    }
                }

                // Recurse through the virtual method so overrides still participate for base types.
                return type.BaseType is null ? null : GetServiceProvider(context, type.BaseType);
            }
            catch (Exception)
            {
                return null;
            }
        }

    }
}
